import json
from matplotlib import pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import time
import os
import re
from tqdm import tqdm
from colorama import Fore
from main import get_args

from gen_plot import DataUtils, STATIC


# Use the SimHei font so Chinese labels render instead of garbled glyphs, and
# draw axis minus signs with the ASCII hyphen-minus instead of the Unicode one.
plt.rcParams.update(
    {
        "font.sans-serif": ["SimHei"],
        "axes.unicode_minus": False,
    }
)


def read_json(file_path):
    """Load and return the JSON document stored at ``file_path``.

    The file is decoded as UTF-8 explicitly (the encoding JSON text is expected
    to use) instead of the platform's locale encoding, which on e.g. a
    Chinese-locale Windows machine is GBK and would mis-decode non-ASCII text.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)


def read_save_info(args):
    """Load the ``{bee_num}_{max_itrs}_save_info.json`` file of one experiment run.

    The run directory is ``output/Net-{net_idx}/{mode}/{evaluator}/{mutation}``;
    when ``args.task_name`` is not the empty string an extra ``/{task_name}``
    level is appended to it.
    """
    run_dir = "/".join(
        [
            "output",
            f"Net-{args.net_idx}",
            f"{args.mode}",
            f"{args.evaluator}",
            f"{args.mutation}",
        ]
    )
    if args.task_name != "":
        run_dir = f"{run_dir}/{args.task_name}"
    file_name = f"{args.bee_num}_{args.max_itrs}_save_info.json"
    return read_json(f"{run_dir}/{file_name}")


# Compare the cost-descent (convergence) curves of different experiments.
def compare_cost(
    dims=("D1", "D2", "D3", "D4", "D5", "D6"),
    losses=("bic",),
    mutations=("mutation",),
    net_idx=1,
    bee_num=20,
    max_itrs=50,
):
    """Plot the best-cost convergence curves of several runs on one figure.

    For every combination of ``dims`` x ``losses`` x ``mutations`` the run's
    save-info JSON is loaded with ``read_save_info`` and its
    ``["convergence"]["best"]`` series is plotted, labelled
    ``{dim}_{loss}_{mutation}``.  The figure is written to
    ``output/analysis/Net-{net_idx}/cost_comparison.png``.

    Args:
        dims: values assigned to ``args.mode``, one run directory each.
        losses: values assigned to ``args.evaluator`` (e.g. "k2", "bic", "bde").
        mutations: values assigned to ``args.mutation``
            (e.g. "mutation", "cross_and_mutation", "cross").
        net_idx, bee_num, max_itrs: override the corresponding fields of the
            namespace returned by ``get_args()``.
    """
    args = get_args()
    args.bee_num = bee_num
    args.max_itrs = max_itrs
    args.net_idx = net_idx
    fig = plt.figure(figsize=(20, 10))
    try:
        for dim_num in dims:
            for loss in losses:
                for mutation in mutations:
                    args.mode = dim_num
                    args.evaluator = loss
                    args.mutation = mutation
                    data = read_save_info(args)
                    cost = data["convergence"]["best"]
                    plt.plot(cost, label=f"{dim_num}_{loss}_{mutation}")
        plt.title(f"Net-{args.net_idx} cost comparison")
        plt.legend()
        cur_dir = f"output/analysis/Net-{args.net_idx}"
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(cur_dir, exist_ok=True)
        save_path = f"{cur_dir}/cost_comparison.png"
        print(f"saving to {Fore.GREEN}{save_path}{Fore.RESET}")
        plt.savefig(save_path)
    finally:
        # Always release the figure, even if reading one of the runs failed.
        plt.close(fig)


def extract_score(data, score_name):
    """Parse confusion-matrix counts and metrics from an evaluation summary string.

    The input looks like::

        "TP=1, TN=6, FP=69, FN=14, Precision=0.014285714285714285, Recall=0.06666666666666667, F1=0.023529411764705882"

    Args:
        data: the evaluation summary string.
        score_name: unused; kept for compatibility with existing callers.
            The full dict of scores is always returned.

    Returns:
        dict with int ``TP``, ``TN``, ``FP``, ``FN`` and float ``Precision``,
        ``Recall``, ``F1`` (in that key order).  A metric whose value is missing
        or not a plain number (e.g. ``nan`` when its denominator was zero) is
        reported as 0.0.

    Raises:
        ValueError: if one of the integer counts is missing from ``data``.
    """
    # Plain decimal number with optional sign, fraction and exponent,
    # e.g. "0.0142", "1", "1e-05".  (The old pattern \d+.\d+ used an
    # unescaped "." and could not match "1" or "1e-05".)
    number = r"([-+]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][-+]?\d+)?)"

    result = {}
    for key in ("TP", "TN", "FP", "FN"):
        match = re.search(rf"\b{key}=(\d+)", data)
        if match is None:
            raise ValueError(f"{key} not found in evaluation string: {data!r}")
        result[key] = int(match.group(1))
    for key in ("Precision", "Recall", "F1"):
        match = re.search(rf"\b{key}=" + number, data)
        result[key] = float(match.group(1)) if match else 0.0
    return result


def compare_F1(
    dims=("D1", "D2", "D3", "D4", "D5", "D6"),
    losses=("bic",),
    mutations=("mutation",),
    net_idx=1,
    bee_num=20,
    max_itrs=50,
):
    """Bar-chart the F1 scores of several runs, sorted from best to worst.

    For every combination of ``dims`` x ``losses`` x ``mutations`` the run's
    save-info JSON is loaded with ``read_save_info`` and its
    ``["union_eval_result"]`` string is parsed with ``extract_score``.  The
    chart is written to ``output/analysis/Net-{net_idx}/F1_comparison.png``.

    Args:
        dims: values assigned to ``args.mode``, one run directory each.
        losses: values assigned to ``args.evaluator`` (e.g. "k2", "bic", "bde").
        mutations: values assigned to ``args.mutation``
            (e.g. "mutation", "cross_and_mutation", "cross").
        net_idx, bee_num, max_itrs: override the corresponding fields of the
            namespace returned by ``get_args()``.
    """
    args = get_args()
    args.net_idx = net_idx
    args.bee_num = bee_num
    args.max_itrs = max_itrs
    fig = plt.figure(figsize=(20, 10))
    try:
        f1_scores = []  # one {"label": ..., "score": <extract_score dict>} per run
        for dim_num in dims:
            for loss in losses:
                for mutation in mutations:
                    args.mode = dim_num
                    args.evaluator = loss
                    args.mutation = mutation
                    data = read_save_info(args)
                    scores = extract_score(data["union_eval_result"], "F1")
                    f1_scores.append(
                        {"label": f"{dim_num}_{loss}_{mutation}", "score": scores}
                    )
        # Highest F1 first.
        f1_scores.sort(key=lambda item: item["score"]["F1"], reverse=True)
        labels = [item["label"] for item in f1_scores]
        values = [item["score"]["F1"] for item in f1_scores]
        plt.bar(labels, values)
        plt.title(f"Net-{args.net_idx} F1 score comparison")
        plt.xticks(rotation=45, fontsize=16)
        plt.tight_layout()
        cur_dir = f"output/analysis/Net-{args.net_idx}"
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(cur_dir, exist_ok=True)
        save_path = f"{cur_dir}/F1_comparison.png"
        print(f"saving to {Fore.GREEN}{save_path}{Fore.RESET}")
        plt.savefig(save_path)
    finally:
        # Always release the figure, even if reading one of the runs failed.
        plt.close(fig)


def steps_compare(size, normal_list, faster_list, show=True, save_dir=None):
    """Bar-chart per-network convergence iterations with vs. without DTW init.

    For network ``i`` (1-based, in list order) the "without DTW" bar and the
    "with DTW" bar are drawn at the same x position, so the second is overlaid
    on the first.  Dashed horizontal lines mark the mean of each list.

    Args:
        size: size setting shown in the title and used in the file name.
        normal_list: iteration counts without DTW initialization, one per network.
        faster_list: iteration counts with DTW initialization, same order and
            length as ``normal_list``.
        show: if True, call ``plt.show()`` after drawing.
        save_dir: if not None, save the figure to
            ``{save_dir}/steps_compare_size_{size}.png`` (directory created if needed).
    """
    color_normal, color_faster, color_faster_mean = "#FFA07A", "#20B2AA", "#4682B4"
    n_nets = len(normal_list)
    # One bar position per network instead of a hard-coded 5.
    x = np.arange(1, n_nets + 1)
    fig = plt.figure(figsize=(6, 5))
    try:
        plt.bar(x, normal_list, label="不使用DTW初始化", width=0.5, color=color_normal)
        plt.bar(x, faster_list, label="使用DTW初始化", width=0.5, color=color_faster)
        plt.axhline(
            y=np.mean(normal_list),
            color=color_normal,
            linestyle="--",
            label="不使用DTW的平均值",
        )
        plt.axhline(
            y=np.mean(faster_list),
            color=color_faster_mean,
            linestyle="--",
            label="使用DTW的平均值",
        )
        plt.xticks(x, [f"Net-{i}" for i in range(1, n_nets + 1)])
        plt.ylabel("迭代次数")
        plt.title(f"收敛所需的迭代次数比较（Size-{size}）")
        plt.legend(loc="center left", bbox_to_anchor=(1, 0.5), frameon=False)
        plt.tight_layout()
        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)
            size_path = f"{save_dir}/steps_compare_size_{size}.png"
            plt.savefig(size_path)
            print(f"{Fore.BLUE}save to {size_path}{Fore.RESET}")
        if show:
            plt.show()
    finally:
        # pyplot keeps every figure alive until it is closed; without this each
        # call (notably with show=False) leaked one figure.
        plt.close(fig)


def compareDTW_steps(show=True, save_dir=None):
    """Compare convergence steps of normal vs. DTW-initialized ("faster") runs.

    For networks 1-5 and sizes 10 and 100, each run's best-cost curve is loaded
    with ``DataUtils.read_cost_data`` and the index of its first minimum is
    taken as the number of steps needed to converge.  The per-network steps,
    their means and the relative step reduction of the faster runs are
    printed, plotted via ``steps_compare`` and, if ``save_dir`` is given,
    written to ``{save_dir}/steps_compare.json``.

    Args:
        show: forwarded to ``steps_compare``.
        save_dir: output directory for the figures and the JSON summary, or None.
    """
    # Insertion order here fixes the key order of the JSON output.
    step_lists = {
        "size_10": [],
        "size_100": [],
        "size_10_faster": [],
        "size_100_faster": [],
    }
    # (step_lists key, size, task), in the order the runs were originally read.
    runs = (
        ("size_10_faster", 10, STATIC.FASTER_TASK),
        ("size_100_faster", 100, STATIC.FASTER_TASK),
        ("size_10", 10, STATIC.NORMAL_TASK),
        ("size_100", 100, STATIC.NORMAL_TASK),
    )
    for idx in range(1, 6):
        for key, size, task in runs:
            best_list, _, _ = DataUtils.read_cost_data(
                idx, size, scale_factor=1, add_factor=0, task=task
            )
            # np.argmin returns the first index at which the minimum occurs.
            step_lists[key].append(int(np.argmin(best_list)))

    def _speedup(baseline, faster):
        # Relative step reduction, rounded to 2 decimals.  A zero baseline
        # (every normal run already best at step 0) would otherwise divide
        # by zero and yield nan/inf in the report.
        if baseline == 0:
            return 0.0
        return round((baseline - faster) / baseline, 2)

    result = {"step_lists": step_lists}
    result.update({key: np.mean(steps) for key, steps in step_lists.items()})
    result["size_10_speedup"] = _speedup(result["size_10"], result["size_10_faster"])
    result["size_100_speedup"] = _speedup(
        result["size_100"], result["size_100_faster"]
    )
    print(result)

    steps_compare(
        10, step_lists["size_10"], step_lists["size_10_faster"], show, save_dir
    )
    steps_compare(
        100, step_lists["size_100"], step_lists["size_100_faster"], show, save_dir
    )
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        print(f"{Fore.BLUE}saving to {save_dir}/steps_compare.json{Fore.RESET}")
        with open(f"{save_dir}/steps_compare.json", "w") as f:
            json.dump(result, f, indent=4)


def compare_with_dyn(
    abc_metrics_path="output/report/ours/metrics.xlsx",
    dyn_path="output/report/dyn/dyn_size10_info.json",
    save_path="output/report/dyn/F1_score_compare.png",
):
    """Bar-chart per-network F1 scores of the ABC method against dynGENIE3.

    ABC scores are read from the ``F1`` column of the spreadsheet at
    ``abc_metrics_path``; dynGENIE3 scores are the ``f1_score`` field of each
    value in ``["results"]`` of the JSON at ``dyn_path``.  Both series end with
    an "Average" bar, and the chart is saved to ``save_path``.

    Call once per setting: e.g. for the size-100 experiments pass that run's
    ``dyn_path`` (and metrics file) together with a distinct ``save_path``
    (a previous hard-coded size-100 pass re-read the size-10 files and
    overwrote the same PNG).

    Raises:
        ValueError: if the number of dynGENIE3 results matches neither the
            number of networks nor the number of networks plus one.
    """
    light_color, green_color = "#C4A5DE", "#20B2AA"

    print(f"{Fore.GREEN}Reading from {abc_metrics_path}{Fore.RESET}")
    df = pd.read_excel(abc_metrics_path)
    f1_scores = df["F1"].values.tolist()
    n_nets = len(f1_scores)
    # Mean over networks goes last, as the "Average" group.
    f1_scores.append(np.mean(f1_scores))

    print(f"{Fore.GREEN}Reading from {dyn_path}{Fore.RESET}")
    with open(dyn_path, "r", encoding="utf-8") as f:
        dyn_info = json.load(f)
    dyn_f1_scores = [item["f1_score"] for item in dyn_info["results"].values()]
    if len(dyn_f1_scores) == n_nets:
        # Give dynGENIE3 its own "Average" bar so both series have equal length.
        dyn_f1_scores.append(np.mean(dyn_f1_scores))
    elif len(dyn_f1_scores) != n_nets + 1:
        # n_nets + 1 entries: assumed to already end with an average — TODO confirm.
        raise ValueError(
            f"{dyn_path} has {len(dyn_f1_scores)} results, "
            f"expected {n_nets} (one per network) or {n_nets + 1}"
        )

    fig = plt.figure(figsize=(10, 6))
    try:
        x = np.arange(len(f1_scores))
        width = 0.35
        plt.bar(x - width / 2, f1_scores, width, label="ABC Method", color=green_color)
        plt.bar(
            x + width / 2,
            dyn_f1_scores,
            width,
            label="dynGENIE3 Method",
            color=light_color,
        )
        plt.xlabel("Experiment")
        plt.ylabel("F1 Score")
        plt.grid()
        plt.title("Comparison of F1 Scores")
        plt.xticks(x, [f"Net-{i}" for i in range(1, n_nets + 1)] + ["Average"])
        plt.legend()
        plt.tight_layout()
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        print(f"{Fore.BLUE}save to {save_path}{Fore.RESET}")
        plt.savefig(save_path)
    finally:
        plt.close(fig)




if __name__ == "__main__":
    # Uncomment whichever analysis should run:
    #   compare_cost()
    #   compare_F1()
    #   compareDTW_steps(show=False, save_dir="output/report/compare/DTW")
    #   compare_with_dyn()
    print("DONE~!")
